Mouse MOp (MERFISH)¶
Importing¶
In [1]:
import enclus
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
Preprocessing scRNA-ref¶
In [2]:
# Load the snRNA-seq reference (20 cell types, 13516 cells x 21158 genes)
# and display the maximum expression value as a quick scale check.
cell_type_column = 'subclass_label'
sc_data = sc.read('datasets/Ref_snRNA_mop_qc3_2Kgenes.h5ad')
sc_data.X.max()
Out[2]:
np.float32(7.5840874)
In [3]:
from scipy.sparse import issparse

# Reload both datasets so this cell is idempotent: the renames and the
# cell-type filters below replace `st_data` / `sc_data`.
st_data = sc.read('datasets/MERFISH_mop.h5ad')
sc_data = sc.read('datasets/Ref_snRNA_mop_qc3_2Kgenes.h5ad')
st_data.obs_names_make_unique()
st_data.var_names_make_unique()

# Harmonise column names: coordinates -> 'x'/'y', labels -> 'cell_type'.
st_data.obs = st_data.obs.rename(columns={'X': 'x', 'Y': 'y'})
st_data.obs = st_data.obs.rename(columns={'subclass': 'cell_type'})
sc_data.obs = sc_data.obs.rename(columns={'subclass_label': 'cell_type'})

# Store the (x, y) coordinates in obsm['spatial'].
spatial_data = st_data.obs[['x', 'y']].values
st_data.obsm['spatial'] = spatial_data

# Densify the reference matrix. `issparse` replaces the deprecated
# `scipy.sparse.csc` / `scipy.sparse.csr` import paths and covers every
# sparse format, not just CSC/CSR matrices.
if issparse(sc_data.X):
    sc_data.X = sc_data.X.toarray()

# Merge 'L4/5 IT' into 'L5 IT' so the spatial labels match the reference.
sp_adata_ct = np.array([label.replace('L4/5 IT', 'L5 IT') for label in st_data.obs['cell_type']])
st_data.obs['cell_type'] = sp_adata_ct

# Keep only cell types present in both datasets; spatial-only labels are
# dropped here. Sorted so the array order is deterministic.
overlap_ct = np.array(sorted(set(st_data.obs['cell_type']) & set(sc_data.obs['cell_type'])))
st_data = st_data[st_data.obs['cell_type'].isin(overlap_ct)].copy()
sc_data = sc_data[sc_data.obs['cell_type'].isin(overlap_ct)].copy()

merfish_data = st_data
ref_data = sc_data

# MERFISH expression matrix (densified so it can be wrapped in a DataFrame)
# and spatial coordinates.
counts_merfish = merfish_data.X.toarray() if issparse(merfish_data.X) else merfish_data.X
coords_merfish = merfish_data.obsm['spatial']
# Total counts per cell as a flat 1-D array.
nUMI_merfish = np.asarray(counts_merfish.sum(axis=1)).ravel()

# Reference expression matrix (dense, see above), labels and per-cell nUMI.
counts_ref = ref_data.X
cell_types_ref = ref_data.obs['cell_type']
nUMI_ref = ref_data.obs['nUMI']

# Export MERFISH counts and coordinates.
merfish_counts_df = pd.DataFrame(
    counts_merfish,
    index=merfish_data.obs_names,
    columns=merfish_data.var_names,
)
merfish_counts_df.to_csv('MERFISH_counts.csv')
coords_df = pd.DataFrame(
    coords_merfish,
    index=merfish_data.obs_names,
    columns=['x', 'y'],
)
coords_df.to_csv('MERFISH_coords.csv')

# Export reference counts and metadata (barcode, cluster label, nUMI).
ref_counts_df = pd.DataFrame(
    counts_ref,
    index=ref_data.obs_names,
    columns=ref_data.var_names,
)
ref_counts_df.to_csv('Ref_counts.csv')
meta_data_ref_df = pd.DataFrame({
    'barcode': ref_data.obs_names,
    'cluster': cell_types_ref,
    'nUMI': nUMI_ref,
})
meta_data_ref_df.to_csv('Ref_meta_data.csv')
Preprocessing MERFISH data¶
We used the single-cell reference to select cell-type marker genes and highly variable genes. We also merged 'L4/5 IT' into 'L5 IT' and removed spatial cell types that have no counterpart in the reference.
In [ ]:
from scipy.sparse import issparse

# Gene panel for imputation: the top-30 rank_genes_groups genes per reference
# group, united with the reference's highly variable genes.
# NOTE(review): depends on `sc_data` from an earlier cell already carrying
# uns['rank_genes_groups'] and var['highly_variable'].
sc_data.raw = sc_data.copy()
cell_type_column = 'subclass_label'  # not used below; kept for reference
markers_df = pd.DataFrame(sc_data.uns["rank_genes_groups"]["names"]).iloc[0:30, :]
markers = list(np.unique(markers_df.melt().value.values))
hvg_genes = set(sc_data.var.loc[sc_data.var['highly_variable'] == 1].index)
markers = list(hvg_genes | set(markers))
print(len(markers))

# Reload and harmonise the MERFISH data (23 cell types, 5551 cells x 254 genes).
st_data = sc.read('./datasets/MERFISH_mop.h5ad')
st_data.obs_names_make_unique()
st_data.var_names_make_unique()
st_data.obs = st_data.obs.rename(columns={'X': 'x', 'Y': 'y'})
st_data.obs = st_data.obs.rename(columns={'subclass': 'cell_type'})
sc_data.obs = sc_data.obs.rename(columns={'subclass_label': 'cell_type'})
spatial_data = st_data.obs[['x', 'y']].values
st_data.obsm['spatial'] = spatial_data

# Log-transform the spatial counts.
sc.pp.log1p(st_data)
print(st_data.X.max())

merfish_genes = st_data.var.index.values.tolist()
# Additional genes of interest appended to the imputation panel.
add_genes = 'Nos1ap Erbb4 Atp2b4 Adamts3 Cdh4 Celf2 Crispld1 Esrrg Htr4 Kcnh5 Prkg1 3110035E14Rik Garnl3 Pvalb Cplx3 Fam84b Slc17a6 Tenm3 Opalin Cdh12 Enpp6 Kcng1 Cux2 Otof Rorb Rspo1 Sulf2 Fezf2 Osr1'.split()
markers = np.unique(markers + add_genes)
print("Markers:", len(markers))

# Densify the reference (`issparse` replaces the deprecated
# scipy.sparse.csc / scipy.sparse.csr import paths).
if issparse(sc_data.X):
    sc_data.X = sc_data.X.toarray()

# Merge 'L4/5 IT' into 'L5 IT'; labels absent from the reference
# (per the original note: 'SMC', 'L6 IT Car3') are dropped by the overlap filter.
sp_adata_ct = np.array([label.replace('L4/5 IT', 'L5 IT') for label in st_data.obs['cell_type']])
st_data.obs['cell_type'] = sp_adata_ct
overlap_ct = np.array(sorted(set(st_data.obs['cell_type']) & set(sc_data.obs['cell_type'])))
st_data = st_data[st_data.obs['cell_type'].isin(overlap_ct)].copy()
sc_data = sc_data[sc_data.obs['cell_type'].isin(overlap_ct)].copy()

# NOTE(review): the processed snRNA file loaded in the next section has var
# columns ('Marker', 'MERFISH_gene') that this cell does not create, so it was
# presumably produced elsewhere -- confirm before regenerating with these writes.
# st_data.write('./datasets/mop/processed_MERFISH_mop.h5ad')
# sc_data.write('./datasets/mop/processed_snRNA_mop.h5ad')
1931 5.3889804 Markers: 1938
Load processed data¶
In [3]:
# Load the processed spatial and reference AnnData objects.
st_data = sc.read_h5ad('./datasets/mop/processed_MERFISH_mop.h5ad')
sc_data = sc.read_h5ad('./datasets/mop/processed_snRNA_mop.h5ad')
print(sc_data, st_data)

# Genes of interest passed to the model as `sc_genes`.
add_genes = (
    'Nos1ap Erbb4 Atp2b4 Adamts3 Cdh4 Celf2 Crispld1 Esrrg Htr4 Kcnh5 Prkg1 3110035E14Rik Garnl3 Pvalb Cplx3 Fam84b Slc17a6 Tenm3 Opalin Cdh12 Enpp6 Kcng1 Cux2 Otof Rorb Rspo1 Sulf2 Fezf2 Osr1'
).split()
AnnData object with n_obs × n_vars = 13516 × 21158
obs: 'QC', 'batch', 'class_color', 'class_id', 'class_label', 'cluster_color', 'cluster_labels', 'dataset', 'date', 'ident', 'individual', 'nCount_RNA', 'nFeature_RNA', 'nGene', 'nUMI', 'project', 'region', 'species', 'subclass_id', 'cell_type', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'n_counts', 'subclass_label_R'
var: 'mt', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'Marker', 'MERFISH_gene'
uns: 'hvg', 'neighbors', 'pca', 'rank_genes_groups', 'subclass_label_colors', 'umap'
obsm: 'X_pca', 'X_umap'
varm: 'PCs'
layers: 'logcounts'
obsp: 'connectivities', 'distances' AnnData object with n_obs × n_vars = 5381 × 254
obs: 'sample_id', 'slice_id', 'class_label', 'cell_type', 'label', 'x', 'y'
obsm: 'spatial'
In [4]:
# Maximum expression of the reference and spatial matrices (scale sanity check).
scale_check = (sc_data.X.max(), st_data.X.max())
scale_check
Out[4]:
(np.float32(7.5840874), np.float32(5.3889804))
Train the SpateCV (ENCLUS) model to impute genes¶
After training, we use the model's gene imputations and its latent embeddings.
In [5]:
# Model hyper-parameters, collected in one place.
enclus_params = dict(
    num_layers=3,
    num_neurons=1024,
    latent_dim=512,
    k_nearest=16,
    num_cov_genes=64,
    num_HVG=1024,
    sc_genes=add_genes,
    spatial_dist="pois",
    sc_dist="nb",
    spatial_coeff=1,
    sc_coeff=1,
    kl_coeff=0.03,
    n_clusters=10,
    tau=0.1,
    gamma=0.1,
    adaptive_weights=True,
    early_stopping=True,
    patience=30,
    num_heads=10,
    head_dim=168,
    distance_metric='euclidean',
)
enclus_model = enclus.ENCLUS(spatial_data=st_data, sc_data=sc_data, **enclus_params)

# Train (early stopping is enabled, so training may end before `training_steps`).
enclus_model.train(
    training_steps=4628,
    batch_size=2048,
    verbose=100,
    init_lr=1e-5,
    decay_steps=4000,
)
enclus_model.impute_genes()

# Copy the model outputs back onto the notebook's AnnData objects.
for obsm_key in ('enclus_latent', 'imputation'):
    st_data.obsm[obsm_key] = enclus_model.spatial_data.obsm[obsm_key]
sc_data.obsm['enclus_latent'] = enclus_model.sc_data.obsm['enclus_latent']
sc_data shape and st_data shape: (13516, 1938) (5381, 247) Initializing CVAE Finished Initializing ENCLUS Initializing cluster centers...
| spatial_w: 3.75 sc_w: 4.04 cov_w: 6.19 kl_w: 0.80 cluster_w: 1.05: 21%|██▏ | 986/4628 [2:40:29<9:52:47, 9.77s/it]
Early stopping triggered
Finished imputing missing gene for spatial data! See 'imputation' in obsm of ENCLUS.spatial_data
View the results of gene imputation¶
In [6]:
# Imputed expression table (cells x imputed genes).
imputed_expression = st_data.obsm['imputation']
imputed_expression
Out[6]:
| index | 1700022I11Rik | 1810046K07Rik | 5730522E02Rik | Acta2 | Adam2 | Adamts2 | Adamts4 | Adra1b | Alk | Ankfn1 | ... | Tal1 | 1700047M11Rik | Tbx1 | Hepacam | Aplp1 | Slc2a1 | Flt3 | Ldb2 | Tnfrsf10b | Gm26522 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 100119755510557417791056480683612014915 | 0.033958 | 0.016518 | 0.467514 | 0.055450 | 0.011491 | 0.022482 | 0.121522 | 0.276107 | 0.012841 | 0.001613 | ... | 0.006594 | 0.015793 | 0.006942 | 0.154774 | 2.848228 | 0.282262 | 0.036911 | 0.967486 | 0.009771 | 0.030900 |
| 100132293312011676013834246988384644937 | 0.017182 | 0.005288 | 0.038738 | 0.007840 | 0.005177 | 0.135361 | 0.324683 | 0.036284 | 0.018930 | 0.005991 | ... | 0.004786 | 1.213946 | 0.006967 | 1.754055 | 2.685546 | 0.339692 | 0.010658 | 0.077153 | 0.012750 | 0.003057 |
| 100141477907895285159644541629937136293 | 0.012541 | 0.033874 | 0.024705 | 0.016275 | 0.013458 | 0.007174 | 0.004996 | 0.047633 | 0.017084 | 0.101747 | ... | 0.013227 | 0.024183 | 0.023657 | 0.450264 | 1.117197 | 0.769400 | 0.044615 | 0.563372 | 0.010724 | 0.031875 |
| 100170194898756150593503685585478903524 | 0.025831 | 0.006050 | 0.124738 | 0.006358 | 0.005296 | 0.008292 | 0.010910 | 0.067310 | 1.180437 | 0.035100 | ... | 0.002178 | 0.006100 | 0.006054 | 0.055215 | 0.607525 | 0.071660 | 0.826093 | 0.144919 | 0.008019 | 0.013541 |
| 100221098919514132063709706431102588200 | 0.013862 | 0.007168 | 0.099484 | 0.007788 | 0.016363 | 0.045896 | 0.028138 | 0.792513 | 0.006974 | 0.002345 | ... | 0.009607 | 0.015191 | 0.023333 | 0.156937 | 1.545135 | 0.106255 | 0.033739 | 0.398914 | 0.003060 | 0.013868 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 99761023417729303996882966126171608555 | 0.021530 | 0.010730 | 0.064236 | 0.016746 | 0.065895 | 0.022191 | 0.009726 | 0.274346 | 0.006460 | 0.012128 | ... | 0.023330 | 0.018904 | 0.016180 | 0.116808 | 2.861451 | 0.355431 | 0.035919 | 1.282821 | 0.014298 | 0.019250 |
| 9976139147878378112755895136358908028 | 0.027440 | 0.009032 | 0.028702 | 0.003477 | 0.009816 | 0.006668 | 0.001906 | 0.101085 | 0.046885 | 0.008719 | ... | 0.008437 | 0.016010 | 0.005064 | 1.883570 | 0.878763 | 0.423456 | 0.006838 | 0.079814 | 0.011676 | 0.007929 |
| 99775860964963590854647880769830485687 | 0.014351 | 0.013585 | 0.039534 | 0.005633 | 0.023213 | 0.010681 | 0.005401 | 0.428636 | 0.013966 | 0.009738 | ... | 0.016675 | 0.009954 | 0.020660 | 0.229057 | 1.067936 | 0.623710 | 0.061817 | 0.729681 | 0.007573 | 0.012526 |
| 99800137893930745778774961383897495957 | 0.019881 | 0.025649 | 0.119920 | 0.028423 | 0.070167 | 0.079435 | 0.130550 | 1.304224 | 0.008272 | 0.002989 | ... | 0.007429 | 0.014031 | 0.029708 | 0.105470 | 3.212263 | 0.087636 | 0.035806 | 0.760579 | 0.007636 | 0.021172 |
| 99972994093375169663255672333494574767 | 0.017218 | 0.035808 | 0.046467 | 0.016681 | 0.023042 | 0.009058 | 0.003639 | 0.146057 | 0.024576 | 0.091198 | ... | 0.018043 | 0.016521 | 0.021321 | 0.333742 | 1.023737 | 0.742046 | 0.043975 | 0.830922 | 0.010382 | 0.032604 |
5381 rows × 1938 columns
View the latent representations of scRNA-seq and ST¶
Defining cell type color palette
In [ ]:
import os

import umap.umap_ as umap

# Fixed seed so the joint UMAP (and latent.pdf) is reproducible across runs.
UMAP_SEED = 0

# Joint 2-D UMAP of the spatial and reference latent spaces; spatial cells first.
fit = umap.UMAP(
    n_neighbors=50,
    min_dist=0.5,
    n_components=2,
    random_state=UMAP_SEED,
)
n_spatial = st_data.shape[0]
latent_umap = fit.fit_transform(
    np.concatenate([st_data.obsm['enclus_latent'], sc_data.obsm['enclus_latent']], axis=0)
)
st_data.obsm['latent_umap'] = latent_umap[:n_spatial]
sc_data.obsm['latent_umap'] = latent_umap[n_spatial:]

# Shared axis limits: drop the outer `pre` percent at each end, then pad by `delta`.
lim_arr = np.concatenate([st_data.obsm['latent_umap'], sc_data.obsm['latent_umap']], axis=0)
delta = 1
pre = 0.1
xmin = np.percentile(lim_arr[:, 0], pre) - delta
xmax = np.percentile(lim_arr[:, 0], 100 - pre) + delta
ymin = np.percentile(lim_arr[:, 1], pre) - delta
ymax = np.percentile(lim_arr[:, 1], 100 - pre) + delta

# Cell-type palette shared by every plot below ('L4/5 IT' reuses the 'L5 IT' colour).
color_dict = {'Astro': '#1f77b4',
              'Endo': '#aec7e8',
              'L2/3 IT': '#ff7f0e',
              'L5 ET': '#ffbb78',
              'L5 IT': '#2ca02c',
              'L5/6 NP': '#98df8a',
              'L6 CT': '#d62728',
              'L6 IT': '#ff9896',
              'L6b': '#9467bd',
              'Lamp5': '#c5b0d5',
              'Micro': '#8c564b',
              'Oligo': '#c49c94',
              'OPC': '#e377c2',
              'Peri': '#f7b6d2',
              'Pvalb': '#7f7f7f',
              'PVM': '#a3a2a2',
              'Sncg': '#bcbd22',
              'Sst': '#dbdb8d',
              'Vip': '#17becf',
              'VLMC': '#9edae5',
              'None': '#dbd9d9',
              'SMC': '#B87BCE',
              'L6 IT Car3': '#82A8CE',
              'L4/5 IT': '#2ca02c',
              }
# Legend order for the 20 shared cell types.
labelnames = ['Astro', 'Endo', 'L2/3 IT', 'L5 IT', 'L5 ET', 'L5/6 NP', 'L6 IT', 'L6 CT', 'L6b', 'Lamp5', 'Micro', 'Oligo', 'OPC', 'Peri', 'Pvalb',
              'PVM', 'Sncg', 'Sst', 'Vip', 'VLMC']

fig, axes = plt.subplots(1, 2, figsize=(13, 5))
sns.scatterplot(x=sc_data.obsm['latent_umap'][:, 0],
                y=sc_data.obsm['latent_umap'][:, 1], hue=sc_data.obs['cell_type'], s=8, palette=color_dict,
                legend=False, ax=axes[0])
axes[0].set_title("snRNA-seq Latent")

sns.scatterplot(x=st_data.obsm['latent_umap'][:, 0],
                y=st_data.obsm['latent_umap'][:, 1], hue=st_data.obs['cell_type'], s=8, palette=color_dict,
                legend=True, ax=axes[1])
legend = axes[1].legend(title='Cell Type', prop={'size': 12}, fontsize='12', markerscale=3, ncol=2,
                        bbox_to_anchor=(1.2, 1))
legend.get_title().set_fontsize('12')
axes[1].set_title("st_data Latent")

for panel in axes:
    panel.set_xlim([xmin, xmax])
    panel.set_ylim([ymin, ymax])
    panel.axis('off')
fig.tight_layout()

# Create the output directory so savefig does not fail on a fresh checkout.
os.makedirs('./Result/mop', exist_ok=True)
fig.savefig('./Result/mop/latent.pdf')
# plt.show()
Per-gene cosine similarity and Pearson correlation between measured and imputed expression¶
In [10]:
from sklearn.metrics.pairwise import cosine_similarity
import scipy.stats as st

# Per-gene agreement between measured MERFISH expression and the imputation,
# over genes present in both. Cells are matched by position; obsm rows are
# aligned with st_data's cells.
sp_adata = st_data.copy()
generated_cells = st_data.obsm['imputation']
sp_adata_SS = sp_adata.copy()
# Sorted so the gene order is deterministic (set iteration order is not).
overlaped_genes = np.array(sorted(set(sp_adata_SS.var.index) & set(generated_cells.columns)))
sp_adata_SS = sp_adata_SS[:, overlaped_genes].copy()
generated_cells = generated_cells.loc[:, overlaped_genes].T  # genes x cells

raw = sp_adata_SS.to_df()   # cells x genes, measured
impute = generated_cells.T  # cells x genes, imputed

# Cosine similarity per gene, computed column-wise in O(cells * genes) rather
# than as the diagonal of a genes x genes matrix. Genes with a zero vector
# get 0, as sklearn's cosine_similarity does.
raw_values = raw.to_numpy(dtype=float)
impute_values = impute.to_numpy(dtype=float)
dots = (raw_values * impute_values).sum(axis=0)
norms = np.linalg.norm(raw_values, axis=0) * np.linalg.norm(impute_values, axis=0)
per_gene_cosine = np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)

temp = pd.DataFrame({'cosine similarity': per_gene_cosine}, index=overlaped_genes)
cosine = temp['cosine similarity'].mean()
print('cosine similarity:', cosine)

# Pearson correlation (PCC) per gene; NaNs are replaced by a tiny constant.
# Collected in a dict and framed once instead of concatenating in the loop.
pcc_by_gene = {}
for gene in raw.columns:
    pcc_by_gene[gene], _ = st.pearsonr(raw[gene].fillna(1e-20), impute[gene].fillna(1e-20))
result = pd.DataFrame(pcc_by_gene, index=["PCC"])
print(result.median(axis=1))
cosine similarity: 0.6716988 PCC 0.622068 dtype: float32
Save results¶
In [14]:
import os

# Persist both AnnData objects (latent embeddings + imputation) for the
# downstream section, which reloads ENVI_mop.h5ad. st_data must be *written*
# here; reading the file would discard this run's imputation.
os.makedirs('./datasets/mop', exist_ok=True)
os.makedirs('./Result/mop', exist_ok=True)
sc_data.write('./datasets/mop/ENVI_Ref_snRNA_mop_qc3_2Kgenes.h5ad')
st_data.write('./datasets/mop/ENVI_mop.h5ad')
st_data.obsm['imputation'].to_csv('./Result/mop/SpateCV_impute.csv', header=True, index=True)
Downstream analysis¶
In [2]:
# Reload the saved spatial AnnData (carries obsm['imputation'] used below).
st_data = sc.read('./datasets/mop/ENVI_mop.h5ad')
In [5]:
from matplotlib.gridspec import GridSpec
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
# Plot coordinates (y negated for display).
spatial_xy = st_data.obsm['spatial']
x = spatial_xy[:, 0]
y = -spatial_xy[:, 1]

# Rebuild the gene panel: reference HVGs plus the genes of interest.
sc_data = sc.read_h5ad('./datasets/mop/processed_snRNA_mop.h5ad')
hvg_mask = sc_data.var['highly_variable'] == 1
markers = list(set(sc_data.var.index[hvg_mask]))
add_genes = (
    'Nos1ap Erbb4 Atp2b4 Adamts3 Cdh4 Celf2 Crispld1 Esrrg Htr4 Kcnh5 Prkg1 3110035E14Rik Garnl3 Pvalb Cplx3 Fam84b Slc17a6 Tenm3 Opalin Cdh12 Enpp6 Kcng1 Cux2 Otof Rorb Rspo1 Sulf2 Fezf2 Osr1'
).split()
markers = np.unique(markers + add_genes)
print("Markers:", len(markers))

# Split into panel genes measured by MERFISH and genes of interest that are not.
st_genes = st_data.var_names.tolist()
st_gene_set = set(st_genes)
merfish_genes = [gene for gene in markers if gene in st_gene_set]
merfish_gene_set = set(merfish_genes)
nonmerfish_genes = [gene for gene in add_genes if gene not in merfish_gene_set]
print(len(merfish_genes), len(nonmerfish_genes))
Markers: 1938 247 13
Method comparison¶
We compare the gene-expression predictions of several imputation methods (the SpateCV imputation covers 1938 genes). The figures below focus on selected genes.
In [18]:
# SpateCV imputation comes from the AnnData; the other methods from CSV exports.
# NOTE(review): read_csv has no index_col, so a saved index appears as an
# 'Unnamed: 0' column; downstream code matches rows by position.
df_SpateCV = st_data.obsm['imputation']
method_csv_paths = [
    './Result/mop/ENVI_impute.csv',
    './Result/mop/SpaGE_impute.csv',
    './Result/mop/Tangram_impute.csv',
    './Result/mop/gimVI_impute.csv',
    './Result/mop/stPlus_impute.csv',
]
df_ENVI, df_SpaGE, df_Tangram, df_gimVI, df_stPlus = [pd.read_csv(path) for path in method_csv_paths]
We selected seven marker genes with clear spatial expression patterns.
In [6]:
def plot_predictGene_comparison(predicted_genes,
                                df_SpateCV,
                                df_ENVI,
                                df_SpaGE,
                                df_Tangram,
                                df_gimVI,
                                df_stPlus,
                                st_data,
                                title="MERFISH test genes",
                                save_path='./Result/mop/predictGeneCompare.pdf'):
    """Plot measured vs. imputed spatial expression, one row per gene.

    Column 0 holds the gene name, column 1 the measured ("Ground Truth")
    expression from ``st_data``, and the remaining columns each method's
    prediction. Every panel uses ``log(expr + 0.1)``, and all panels share one
    colour range (20th-95th percentile over all panels).

    Parameters
    ----------
    predicted_genes : list of str
        Genes to plot; each must be a column of ``st_data.to_df()`` and of
        every method DataFrame.
    df_SpateCV, df_ENVI, df_SpaGE, df_Tangram, df_gimVI, df_stPlus : pandas.DataFrame
        Cells x genes predictions; rows are matched to ``st_data`` by position.
    st_data : AnnData
        Spatial data providing the measured expression and ``obsm['spatial']``.
    title : str
        Currently unused; kept for backward compatibility.
    save_path : str or None
        Output file for the figure; ``None`` skips saving.

    Note: calls ``sns.set``, which changes the global seaborn theme.
    """
    sns.set(style="white", context="paper", font_scale=2.5)

    # Coordinates shared by every panel (y negated for display).
    x_coords = st_data.obsm['spatial'][:, 0]
    y_coords = -st_data.obsm['spatial'][:, 1]

    df_truth = st_data.to_df()[predicted_genes]  # (n_cells, n_genes)
    panels = [
        ('Ground Truth', df_truth),
        ('SpateCV', df_SpateCV),
        ('ENVI', df_ENVI),
        ('SpaGE', df_SpaGE),
        ('Tangram', df_Tangram),
        ('gimVI', df_gimVI),
        ('stPlus', df_stPlus),
    ]
    # log(expr + 0.1) per panel, each of shape (n_cells, n_genes).
    log_expr = {name: np.log(df_[predicted_genes].values + 0.1) for name, df_ in panels}

    # One colour range for all panels so intensities are comparable.
    combined_expression = np.concatenate(list(log_expr.values()), axis=0)
    vmin = np.percentile(combined_expression, 20)
    vmax = np.percentile(combined_expression, 95)

    n_genes = len(predicted_genes)
    n_methods = len(panels)        # ground truth + 6 methods
    n_total_cols = n_methods + 1   # extra leading column for gene names
    fig = plt.figure(figsize=(2 * n_total_cols, 2 * n_genes), dpi=300)
    gs = GridSpec(n_genes, n_total_cols, figure=fig, wspace=0.1, hspace=0.01,
                  width_ratios=[1] + [4] * n_methods)

    for row_idx, gene in enumerate(predicted_genes):
        # Leftmost column: gene name only.
        name_ax = fig.add_subplot(gs[row_idx, 0])
        name_ax.text(0.5, 0.5, gene, fontsize=16, ha='center', va='center', transform=name_ax.transAxes)
        name_ax.axis('off')

        for col_idx, (method_name, _df) in enumerate(panels, start=1):
            ax = fig.add_subplot(gs[row_idx, col_idx])
            ax.scatter(
                x_coords,
                y_coords,
                c=log_expr[method_name][:, row_idx],
                cmap='Reds',
                vmin=vmin,
                vmax=vmax,
                s=5,
                edgecolor='none',
                alpha=0.8,
            )
            if row_idx == 0:
                ax.set_title(method_name, fontsize=14, pad=10)
            ax.set_aspect('equal')
            ax.axis('off')

    if save_path is not None:
        plt.savefig(save_path)
    plt.show()
# Seven MERFISH-measured marker genes with clear spatial patterns.
compare_gene = [
    'Osr1', 'Otof', 'Slc17a6', 'Fam84b',
    'Opalin', 'Cdh12', 'Fezf2',
]
plot_predictGene_comparison(
    compare_gene,
    df_SpateCV, df_ENVI, df_SpaGE, df_Tangram, df_gimVI, df_stPlus,
    st_data,
    title="MERFISH test genes",
)
Calculate the MAE between the predicted results and the true results¶
The `calculate_mae` function computes, for each method, the per-gene MAE against the measured MERFISH expression for the genes in `compare_gene`. `plot_mae` then draws bar charts of the per-gene and average MAE.
In [9]:
from sklearn.metrics import mean_absolute_error
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
def calculate_mae(predicted_genes, df_truth, df_methods, methods):
    """Compute each method's mean absolute error per gene.

    Parameters
    ----------
    predicted_genes : list of str
        Genes to score; each must be a column of ``df_truth`` and of every
        DataFrame in ``df_methods``.
    df_truth : pandas.DataFrame
        Measured expression, cells x genes.
    df_methods : list of pandas.DataFrame
        Predictions, one per method; rows are compared to ``df_truth`` by position.
    methods : list of str
        Method names, aligned with ``df_methods``.

    Returns
    -------
    mae_results : pandas.DataFrame
        Genes x methods table of MAE values (float).
    average_mae : pandas.DataFrame
        Single column 'Average MAE' indexed by method: mean MAE over genes.
    """
    mae_by_method = {}
    for method, df_pred in zip(methods, df_methods):
        mae_by_method[method] = [
            mean_absolute_error(df_truth[gene], df_pred[gene]) for gene in predicted_genes
        ]
    mae_results = pd.DataFrame(mae_by_method, index=predicted_genes, columns=methods).astype(float)
    average_mae = mae_results.mean().to_frame(name='Average MAE')
    return mae_results, average_mae
def plot_mae(mae_results, average_mae, methods,
             save_path='./Result/mop/MAE-Compare.pdf',
             average_save_path='./Result/mop/Average-MAE-Compare.pdf'):
    """Draw two bar charts: per-gene MAE by method, and average MAE per method.

    Parameters
    ----------
    mae_results : pandas.DataFrame
        Genes x methods MAE table with an unnamed index (as returned by
        ``calculate_mae``), so ``reset_index`` yields an 'index' column.
    average_mae : pandas.DataFrame
        'Average MAE' column indexed by method (unnamed index).
    methods : list of str
        Method columns of ``mae_results`` to plot.
    save_path, average_save_path : str or None
        Output files for the two figures; ``None`` skips saving that figure.

    Note: calls ``sns.set(style="whitegrid", ...)`` before the second figure,
    which changes the global seaborn theme for later plots.
    """
    # Long format: one row per (gene, method) pair.
    mae_long = (
        mae_results.reset_index()
        .melt(id_vars='index', value_vars=methods, var_name='Method', value_name='MAE')
        .rename(columns={'index': 'Gene'})
    )

    # Figure 1: grouped bars of per-gene MAE.
    fig, ax = plt.subplots(figsize=(4, 3), dpi=300)
    sns.barplot(data=mae_long, x='Gene', y='MAE', hue='Method', palette='Set3', ax=ax)
    ax.set_title('MAE between Predicted Marker Genes and Ground Truth', fontsize=8)
    ax.legend(title='Method', fontsize=4, title_fontsize=6)
    plt.setp(ax.get_xticklabels(), rotation=0, ha='center', fontsize=8)
    plt.setp(ax.get_yticklabels(), rotation=0, ha='right', fontsize=8)
    ax.set_xlabel("")
    ax.set_ylabel("")
    sns.despine(ax=ax)
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

    # Figure 2: average MAE per method. The theme is set after the figure is
    # created but before its axes, so the axes pick up the whitegrid style.
    fig_avg = plt.figure(figsize=(4, 3), dpi=300)
    sns.set(style="whitegrid", context="paper", font_scale=1.1)
    ax_avg = fig_avg.add_subplot()
    sns.barplot(
        data=average_mae.reset_index(),
        x='index',
        y='Average MAE',
        palette='Set3',
        width=0.5,
        ax=ax_avg,
    )
    ax_avg.set_title('Average MAE across Predicted Marker Genes', fontsize=10, pad=12)
    plt.setp(ax_avg.get_xticklabels(), rotation=0, ha='center', fontsize=10)
    plt.setp(ax_avg.get_yticklabels(), rotation=0, ha='right', fontsize=10)
    ax_avg.set_xlabel("")
    ax_avg.set_ylabel("")
    # Remove the top and right borders.
    sns.despine(ax=ax_avg)
    fig_avg.tight_layout()
    if average_save_path is not None:
        fig_avg.savefig(average_save_path, bbox_inches='tight')
    plt.show()
# Score the selected genes against the measured MERFISH expression.
method_results = {
    'SpateCV': df_SpateCV,
    'ENVI': df_ENVI,
    'stPlus': df_stPlus,
    'SpaGE': df_SpaGE,
    'Tangram': df_Tangram,
    'gimVI': df_gimVI,
}
methods = list(method_results)
df_methods = list(method_results.values())
df_truth = st_data.to_df()[merfish_genes]  # (n_cells, n_genes)
mae_results, average_mae = calculate_mae(compare_gene, df_truth, df_methods, methods)
plot_mae(mae_results, average_mae, methods)
Imputation of Non-MERFISH genes¶
In [10]:
def plot_predictNonGene_comparison(df_SpateCV, df_ENVI, df_SpaGE, df_Tangram, df_gimVI, df_stPlus, genes,
                                   x, y, title='Non-MERFISH Gene Expression',
                                   save_path='./Result/mop/Astro_makerGene.pdf'):
    """Plot each method's imputed spatial expression for ``genes`` (no ground truth).

    Rows are genes; column 0 holds the gene name and the remaining columns one
    method each. All panels use ``log(expr + 0.1)`` and share a colour range
    (20th-95th percentile over every plotted method, ENVI included).

    Parameters
    ----------
    df_SpateCV, df_ENVI, df_SpaGE, df_Tangram, df_gimVI, df_stPlus : pandas.DataFrame
        Cells x genes predictions; rows are matched to ``x``/``y`` by position.
    genes : list of str
        Genes to plot.
    x, y : array-like
        Plot coordinates of the cells.
    title : str
        Currently unused; kept for backward compatibility.
    save_path : str or None
        Output file for the figure; ``None`` skips saving.

    Raises
    ------
    ValueError
        If any gene is missing from one of the method DataFrames.

    Note: calls ``sns.set``, which changes the global seaborn theme.
    """
    sns.set(style="white", context="paper", font_scale=2.5)

    panels = [
        ('SpateCV', df_SpateCV),
        ('ENVI', df_ENVI),
        ('SpaGE', df_SpaGE),
        ('Tangram', df_Tangram),
        ('gimVI', df_gimVI),
        ('stPlus', df_stPlus),
    ]

    # Validate every plotted method (ENVI used to be skipped here) and pool
    # its log-expression so the shared colour range covers all panels.
    all_expr_values = []
    for _method_name, df_ in panels:
        missing_genes = set(genes) - set(df_.columns)
        if missing_genes:
            raise ValueError(f"The following genes do not exist in the DataFrame: {missing_genes}")
        all_expr_values.append(np.log(df_[genes].values + 0.1).flatten())
    all_expr_values = np.concatenate(all_expr_values)
    vmin = np.percentile(all_expr_values, 20)
    vmax = np.percentile(all_expr_values, 95)

    n_genes = len(genes)
    n_methods = len(panels)
    n_total_cols = n_methods + 1  # extra leading column for gene names
    fig = plt.figure(figsize=(2 * n_total_cols, 2 * n_genes), dpi=300)
    gs = GridSpec(n_genes, n_total_cols, figure=fig, wspace=0.1, hspace=0.01,
                  width_ratios=[1] + [4] * n_methods)

    for row_idx, gene in enumerate(genes):
        name_ax = fig.add_subplot(gs[row_idx, 0])
        name_ax.text(0.95, 0.5, gene, fontsize=16, ha='right', va='center', transform=name_ax.transAxes)
        name_ax.axis('off')

        # Every gene was validated above, so the column lookup cannot fail.
        for col_idx, (method_name, df_) in enumerate(panels, start=1):
            ax = fig.add_subplot(gs[row_idx, col_idx])
            ax.scatter(
                x=x,
                y=y,
                c=np.log(df_[gene] + 0.1),
                cmap='Reds',
                vmin=vmin,
                vmax=vmax,
                s=5,
                edgecolor='none',
                alpha=0.8,
            )
            if row_idx == 0:
                ax.set_title(method_name, fontsize=14, pad=15)
            ax.set_aspect('equal')
            ax.axis('off')

    if save_path is not None:
        plt.savefig(save_path)
    plt.show()
# Genes of interest that the MERFISH panel does not measure.
st_gene_lookup = set(st_genes)
None_gene = [gene for gene in add_genes if gene not in st_gene_lookup]

# Astrocyte differentially expressed genes to visualise.
genes = ['Gfap', 'Aqp4', 'Nfia', 'Hepacam', 'Nxn', 'Ptprz1', 'Gramd3']
plot_predictNonGene_comparison(
    df_SpateCV, df_ENVI, df_SpaGE, df_Tangram, df_gimVI, df_stPlus,
    genes, x, y,
    title='Non-MERFISH Gene Expression',
)
All cell types
In [63]:
# Spatial map of all shared cell types.
sns.set_context('paper', font_scale=2)
fig, ax = plt.subplots(figsize=(5, 5), dpi=300)
sns.scatterplot(data=st_data.obs, x="x", y="y", hue="cell_type", hue_order=labelnames,
                s=15, palette=color_dict, ax=ax)
ax.legend(title='Cell Type', prop={'size': 12}, fontsize='12', markerscale=3, ncol=2,
          bbox_to_anchor=(1, 1))
ax.invert_yaxis()
ax.axis('off')
fig.savefig('./Result/mop/MOp-Cell_type.pdf', bbox_inches='tight')
plt.show()
In [19]:
# Spatial map of excitatory neuron types; all other cells are relabelled 'None'.
sns.set_context('paper', font_scale=4.5)
fig, ax = plt.subplots(figsize=(5, 5), dpi=300)
ex_neuronal = ['L2/3 IT', 'L5 IT', 'L5 ET', 'L5/6 NP', 'L6 IT', 'L6 CT', 'L6b']
p2_res = st_data.obs.copy()
# A categorical column must know the 'None' category before it can be
# assigned; add it only if absent (add_categories raises on duplicates).
if p2_res['cell_type'].dtype.name == 'category' and 'None' not in p2_res['cell_type'].cat.categories:
    p2_res['cell_type'] = p2_res['cell_type'].cat.add_categories(['None'])
p2_res.loc[~p2_res['cell_type'].isin(ex_neuronal), 'cell_type'] = 'None'
sns.scatterplot(data=p2_res, x="x", y="y", hue="cell_type", hue_order=ex_neuronal,
                s=15, palette=color_dict, ax=ax)
ax.legend(bbox_to_anchor=(1, 0.72), loc="upper left", fontsize='12', framealpha=0, markerscale=3)
ax.invert_yaxis()
ax.axis('off')
fig.savefig('./Result/mop/Celltype-layer.pdf', bbox_inches='tight')
plt.show()
In [22]:
# Spatial map of non-neuronal types; all other cells are relabelled 'None'.
sns.set_context('paper', font_scale=4.5)
fig, ax = plt.subplots(figsize=(5, 5), dpi=300)
non_neuronal = ['Astro', 'Endo', 'Micro', 'OPC', 'Peri', 'PVM', 'VLMC', 'Oligo']
p2_res = st_data.obs.copy()
# Add the 'None' category only if absent (add_categories raises on duplicates).
if p2_res['cell_type'].dtype.name == 'category' and 'None' not in p2_res['cell_type'].cat.categories:
    p2_res['cell_type'] = p2_res['cell_type'].cat.add_categories(['None'])
p2_res.loc[~p2_res['cell_type'].isin(non_neuronal), 'cell_type'] = 'None'
sns.scatterplot(data=p2_res, x="x", y="y", hue="cell_type", hue_order=non_neuronal,
                s=15, palette=color_dict, ax=ax)
ax.legend(bbox_to_anchor=(1, 0.72), loc="upper left", fontsize='12', framealpha=0, markerscale=3)
ax.invert_yaxis()
ax.axis('off')
fig.savefig('./Result/mop/Astro-celltype.pdf', bbox_inches='tight')
plt.show()
In [23]:
# Spatial map of inhibitory neuron types; all other cells are relabelled 'None'.
# Note: this changes the shared palette's 'None' colour for any later plots.
color_dict['None'] = '#e3e1e1'
sns.set_context('paper', font_scale=4.5)
fig, ax = plt.subplots(figsize=(5, 5), dpi=150)
inh_neuronal = ['Lamp5', 'Sncg', 'Vip', 'Sst', 'Pvalb']
p2_res = st_data.obs.copy()
# Add the 'None' category only if absent (add_categories raises on duplicates).
if p2_res['cell_type'].dtype.name == 'category' and 'None' not in p2_res['cell_type'].cat.categories:
    p2_res['cell_type'] = p2_res['cell_type'].cat.add_categories(['None'])
p2_res.loc[~p2_res['cell_type'].isin(inh_neuronal), 'cell_type'] = 'None'
sns.scatterplot(data=p2_res, x="x", y="y", hue="cell_type", hue_order=inh_neuronal,
                s=15, palette=color_dict, ax=ax)
ax.legend(bbox_to_anchor=(.92, 0.82), loc="upper left", framealpha=0, markerscale=4.5)
ax.invert_yaxis()
ax.axis('off')
plt.show()